{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(sphx_glr_how_to_compile_models_from_pytorch)=\n",
    "# 编译 PyTorch 模型\n",
    "\n",
    "**Author**: [Alex Wong](https://github.com/alexwong/)\n",
    "\n",
    "本文是使用 Relay 部署 PyTorch 模型的入门教程。\n",
    "\n",
    "对于我们来说，首先应该安装 PyTorch。TorchVision 也是必需的，因为我们将使用它作为我们的模型动物园。\n",
    "\n",
    "快速的解决方案是通过 pip 进行安装：\n",
    "\n",
    "```python\n",
    "pip install torch torchvision\n",
    "```\n",
    "\n",
    "或者参考[官方网站](https://pytorch.org/get-started/locally/)。\n",
    "\n",
    "PyTorch 版本应该向后兼容，但应该与适当的 TorchVision 版本一起使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# PyTorch imports\n",
    "import torch\n",
    "import torchvision\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.contrib.download import download_testdata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 载入 PyTorch 预训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    }
   ],
   "source": [
    "model_name = \"resnet18\"\n",
    "model = getattr(torchvision.models, model_name)(pretrained=True)\n",
    "model = model.eval()\n",
    "\n",
    "# 我们通过跟踪获取 TorchScripted 模型\n",
    "input_shape = [1, 3, 224, 224]\n",
    "input_data = torch.randn(input_shape)\n",
    "scripted_model = torch.jit.trace(model, input_data).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载测试图片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "img_url = \"https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true\"\n",
    "img_path = download_testdata(img_url, \"cat.png\", module=\"data\")\n",
    "img = Image.open(img_path).resize((224, 224))\n",
    "\n",
    "# Preprocess the image and convert to tensor\n",
    "from torchvision import transforms\n",
    "\n",
    "my_preprocess = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")\n",
    "img = my_preprocess(img)\n",
    "img = np.expand_dims(img, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导入 Graph 到 Relay\n",
    "\n",
    "将 PyTorch 图转换为 Relay 图。`input_name` 可以是任意的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_name = \"input0\"\n",
    "shape_list = [(input_name, img.shape)]\n",
    "mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Relay 构建\n",
    "\n",
    "使用给定的输入规范将 graph 编译为 llvm 目标："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = tvm.target.Target(\"llvm\", host=\"llvm\")\n",
    "dev = tvm.cpu(0)\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    lib = relay.build(mod, target=target, params=params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 在 TVM 上执行可移植 Graph\n",
    "\n",
    "可以尝试在目标上部署编译后的模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.contrib import graph_executor\n",
    "\n",
    "dtype = \"float32\"\n",
    "m = graph_executor.GraphModule(lib[\"default\"](dev))\n",
    "# Set inputs\n",
    "m.set_input(input_name, tvm.nd.array(img.astype(dtype)))\n",
    "# Execute\n",
    "m.run()\n",
    "# Get outputs\n",
    "tvm_output = m.get_output(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 查找 synset 名称\n",
    "\n",
    "在 1000 类 synset 中查找预测 top 1 索引。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Relay top-1 id: 281, class name: tabby, tabby cat\n",
      "Torch top-1 id: 281, class name: tabby, tabby cat\n"
     ]
    }
   ],
   "source": [
    "synset_url = \"\".join(\n",
    "    [\n",
    "        \"https://raw.githubusercontent.com/Cadene/\",\n",
    "        \"pretrained-models.pytorch/master/data/\",\n",
    "        \"imagenet_synsets.txt\",\n",
    "    ]\n",
    ")\n",
    "synset_name = \"imagenet_synsets.txt\"\n",
    "synset_path = download_testdata(synset_url, synset_name, module=\"data\")\n",
    "with open(synset_path) as f:\n",
    "    synsets = f.readlines()\n",
    "\n",
    "synsets = [x.strip() for x in synsets]\n",
    "splits = [line.split(\" \") for line in synsets]\n",
    "key_to_classname = {spl[0]: \" \".join(spl[1:]) for spl in splits}\n",
    "\n",
    "class_url = \"\".join(\n",
    "    [\n",
    "        \"https://raw.githubusercontent.com/Cadene/\",\n",
    "        \"pretrained-models.pytorch/master/data/\",\n",
    "        \"imagenet_classes.txt\",\n",
    "    ]\n",
    ")\n",
    "class_name = \"imagenet_classes.txt\"\n",
    "class_path = download_testdata(class_url, class_name, module=\"data\")\n",
    "with open(class_path) as f:\n",
    "    class_id_to_key = f.readlines()\n",
    "\n",
    "class_id_to_key = [x.strip() for x in class_id_to_key]\n",
    "\n",
    "# Get top-1 result for TVM\n",
    "top1_tvm = np.argmax(tvm_output.numpy()[0])\n",
    "tvm_class_key = class_id_to_key[top1_tvm]\n",
    "\n",
    "# Convert input to PyTorch variable and get PyTorch result for comparison\n",
    "with torch.no_grad():\n",
    "    torch_img = torch.from_numpy(img)\n",
    "    output = model(torch_img)\n",
    "\n",
    "    # Get top-1 result for PyTorch\n",
    "    top1_torch = np.argmax(output.numpy())\n",
    "    torch_class_key = class_id_to_key[top1_torch]\n",
    "\n",
    "print(f\"Relay top-1 id: {top1_tvm}, class name: {key_to_classname[tvm_class_key]}\")\n",
    "print(f\"Torch top-1 id: {top1_torch}, class name: {key_to_classname[torch_class_key]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "6ee5142ba8a2589df39b0df03e82f50c3ae535c49aaf7d83abad1a0d572c7e37"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('tvm-test')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
